{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1ec627e5-8b8d-4c76-bc2c-519af5b32d20",
   "metadata": {},
   "source": [
    "# Instructions\n",
    "\n",
    "In this tutorial, we perform multi-label classification using an ECG-FM model finetuned on the [MIMIC-IV-ECG v1.0 dataset](https://physionet.org/content/mimic-iv-ecg/1.0/). It covers data and model loading, inference, aggregation of predictions from the same sample, and visualization of embeddings and saliency maps.\n",
    "\n",
    "ECG-FM was developed in collaboration with the [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) framework, which implements a collection of deep learning methods for ECG analysis.\n",
    "\n",
    "Each ECG is segmented into non-overlapping 5 s inputs, and the segment-level predictions are then aggregated per sample using a label-specific method.\n",
    "\n",
    "This notebook serves primarily as a quickstart introduction. Much of this functionality is also available via the [fairseq-signals scripts](https://github.com/bowang-lab/ECG-FM/blob/main/notebooks/infer_cli.ipynb), as well as the [ECG-FM scripts](https://github.com/bowang-lab/ECG-FM/tree/main/scripts)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4d4a9804-4444-4aaa-af00-8c9869cbcc5a",
   "metadata": {},
   "source": [
    "## Installation\n",
    "\n",
    "Begin by cloning [fairseq_signals](https://github.com/Jwoo5/fairseq-signals) and following the installation section of its top-level README. For example, the following commands are sufficient at the time of writing:\n",
    "```\n",
    "# Creating `fairseq` environment:\n",
    "conda create --name fairseq python=3.10.6\n",
    "source activate fairseq\n",
    "git clone https://github.com/Jwoo5/fairseq-signals\n",
    "cd fairseq-signals\n",
    "python3 -m pip install pip==24.0\n",
    "python3 -m pip install -e .\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5992565-e416-4103-a0e7-e2b8a09893f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# You may need the following packages, depending on which functionality you run\n",
    "!pip install huggingface-hub\n",
    "!pip install pandas\n",
    "!pip install ecg-transform==0.1.3\n",
    "!pip install umap-learn\n",
    "!pip install plotly"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 102,
   "id": "1f34c08a-bb4c-4182-a604-e4bc0db0e46b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "root = os.path.dirname(os.getcwd())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ec114e98-ad66-46c3-875f-088a8786781e",
   "metadata": {},
   "source": [
    "## Download checkpoints\n",
    "\n",
    "Checkpoints are available on [HuggingFace](https://huggingface.co/wanglab/ecg-fm). The finetuned model can be downloaded using the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "614f439f-5825-4614-a105-39353c36b5cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from huggingface_hub import hf_hub_download\n",
    "\n",
    "_ = hf_hub_download(\n",
    "    repo_id='wanglab/ecg-fm',\n",
    "    filename='mimic_iv_ecg_finetuned.pt',\n",
    "    local_dir=os.path.join(root, 'ckpts'),\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c2fd0dc-b8f6-48d1-b56d-994cd5aab3e0",
   "metadata": {},
   "source": [
    "# Inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "197b620a-f7da-4fa8-acb2-e1a63a1138fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_path: str = os.path.join(root, 'ckpts/mimic_iv_ecg_finetuned.pt')\n",
    "assert os.path.isfile(ckpt_path)\n",
    "\n",
    "device: str = 'cuda'  # Set to 'cpu' if no GPU is available\n",
    "batch_size: int = 16\n",
    "num_workers: int = 0\n",
    "\n",
    "extract_saliency: bool = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c13e3c2-4dd6-4ea8-a916-3df84778c123",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "def to_list(obj: Any) -> List[Any]:\n",
    "    # Normalize a single object or a collection into a list\n",
    "    if isinstance(obj, list):\n",
    "        return obj\n",
    "\n",
    "    if isinstance(obj, (np.ndarray, set, dict)):\n",
    "        return list(obj)\n",
    "\n",
    "    return [obj]\n",
    "\n",
    "# Collect the preprocessed CODE-15% sample files (sorted for a deterministic order)\n",
    "data_dir = os.path.join(root, 'data/code_15/org')\n",
    "file_paths = [\n",
    "    os.path.join(data_dir, file) for file in sorted(os.listdir(data_dir))\n",
    "]\n",
    "file_paths = to_list(file_paths)\n",
    "file_paths"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c761b64c-0a48-488b-86d0-130418ade807",
   "metadata": {},
   "source": [
    "## Prepare data\n",
    "\n",
    "To simplify this tutorial, we have processed a sample of 10 ECGs (14 segments of 5 s) from the [CODE-15% v1.0.0 dataset](https://zenodo.org/records/4916206/) using our [end-to-end data preprocessing pipeline](https://github.com/Jwoo5/fairseq-signals/tree/master/scripts/preprocess/ecg). Its README is also helpful if you want to run inference on your own data, and preprocessing scripts are already implemented for several public datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "87a9feac-feb1-49aa-a960-69c7190400f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from itertools import chain\n",
    "\n",
    "from scipy.io import loadmat\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.data.dataloader import DataLoader\n",
    "\n",
    "from ecg_transform.inp import ECGInput, ECGInputSchema\n",
    "from ecg_transform.sample import ECGMetadata, ECGSample\n",
    "from ecg_transform.t.base import ECGTransform\n",
    "from ecg_transform.t.common import (\n",
    "    HandleConstantLeads,\n",
    "    LinearResample,\n",
    "    ReorderLeads,\n",
    ")\n",
    "from ecg_transform.t.scale import Standardize\n",
    "from ecg_transform.t.cut import SegmentNonoverlapping\n",
    "\n",
    "class ECGFMDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        schema,\n",
    "        transforms,\n",
    "        file_paths,\n",
    "    ):\n",
    "        self.schema = schema\n",
    "        self.transforms = transforms\n",
    "        self.file_paths = file_paths\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.file_paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        mat = loadmat(self.file_paths[idx])\n",
    "        metadata = ECGMetadata(\n",
    "            sample_rate=int(mat['org_sample_rate'][0, 0]),\n",
    "            num_samples=mat['feats'].shape[1],\n",
    "            lead_names=['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'],\n",
    "            unit=None,\n",
    "            input_start=0,\n",
    "            input_end=mat['feats'].shape[1],\n",
    "        )\n",
    "        metadata.file = self.file_paths[idx]\n",
    "        inp = ECGInput(mat['feats'], metadata)\n",
    "        sample = ECGSample(\n",
    "            inp,\n",
    "            self.schema,\n",
    "            self.transforms,\n",
    "        )\n",
    "        source = torch.from_numpy(sample.out).float()\n",
    "\n",
    "        return source, inp\n",
    "\n",
    "def collate_fn(inps):\n",
    "    sample_ids = list(\n",
    "        chain.from_iterable([[inp[1]]*inp[0].shape[0] for inp in inps])\n",
    "    )\n",
    "    return torch.concatenate([inp[0] for inp in inps]), sample_ids\n",
    "\n",
    "def file_paths_to_loader(\n",
    "    file_paths: List[str],\n",
    "    schema: ECGInputSchema,\n",
    "    transforms: List[ECGTransform],\n",
    "    batch_size = 64,\n",
    "    num_workers = 7,\n",
    "):\n",
    "    dataset = ECGFMDataset(\n",
    "        schema,\n",
    "        transforms,\n",
    "        file_paths,\n",
    "    )\n",
    "\n",
    "    return DataLoader(\n",
    "        dataset,\n",
    "        batch_size=batch_size,\n",
    "        num_workers=num_workers,\n",
    "        pin_memory=True,\n",
    "        sampler=None,\n",
    "        shuffle=False,\n",
    "        collate_fn=collate_fn,\n",
    "        drop_last=False,\n",
    "    )"
   ]
  },
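  {
   "cell_type": "markdown",
   "id": "a3f1c2d4-5e6b-4f70-8a91-b2c3d4e5f601",
   "metadata": {},
   "source": [
    "The `collate_fn` above concatenates the segments of every ECG in a batch along the first axis and repeats each ECG's input object once per segment, so that every row of the batch can be traced back to its source ECG. The NumPy-only sketch below mimics that bookkeeping on dummy arrays. It is illustrative rather than the notebook's code path: `collate_sketch` is a hypothetical helper, and plain strings stand in for the `ECGInput` objects.\n",
    "\n",
    "```python\n",
    "from itertools import chain\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def collate_sketch(items):\n",
    "    # Hypothetical helper. items: list of (segments, sample_id) pairs, where\n",
    "    # segments has shape (n_segments, n_leads, n_samples); n_segments may differ per ECG\n",
    "    sample_ids = list(\n",
    "        chain.from_iterable([[item[1]] * item[0].shape[0] for item in items])\n",
    "    )\n",
    "    # Stack the segments of all ECGs into a single batch along axis 0\n",
    "    batch = np.concatenate([item[0] for item in items])\n",
    "    return batch, sample_ids\n",
    "\n",
    "\n",
    "# Two dummy ECGs: the first yields 2 segments, the second yields 1\n",
    "items = [\n",
    "    (np.zeros((2, 12, 2500)), 'ecg_a'),\n",
    "    (np.ones((1, 12, 2500)), 'ecg_b'),\n",
    "]\n",
    "batch, sample_ids = collate_sketch(items)\n",
    "print(batch.shape)  # (3, 12, 2500)\n",
    "print(sample_ids)   # ['ecg_a', 'ecg_a', 'ecg_b']\n",
    "```\n",
    "\n",
    "Keeping one ID per row is what later allows `pred` to be indexed by file name and grouped back into per-ECG predictions."
   ]
  },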
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "85f8c81f-de69-4af3-be49-ec9e5632b39a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "ECG_FM_LEAD_ORDER = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']\n",
    "SAMPLE_RATE = 500  # Hz\n",
    "N_SAMPLES = SAMPLE_RATE*5  # 5 s segments\n",
    "\n",
    "label_def = pd.read_csv(\n",
    "    os.path.join(root, 'data/mimic_iv_ecg/labels/label_def.csv'),\n",
    "    index_col='name',\n",
    ")\n",
    "label_names = label_def.index.to_list()\n",
    "label_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "082bd08b-832e-4f58-9d56-0e069ce2b710",
   "metadata": {},
   "outputs": [],
   "source": [
    "AGG_METHODS = {\n",
    "    'Poor data quality': 'max',\n",
    "    'Sinus rhythm': 'mean',\n",
    "    'Premature ventricular contraction': 'max',\n",
    "    'Tachycardia': 'mean',\n",
    "    'Ventricular tachycardia': 'max',\n",
    "    'Supraventricular tachycardia with aberrancy': 'max',\n",
    "    'Bradycardia': 'mean',\n",
    "    'Infarction': 'mean',\n",
    "    'Atrioventricular block': 'mean',\n",
    "    'Right bundle branch block': 'mean',\n",
    "    'Left bundle branch block': 'mean',\n",
    "    'Electronic pacemaker': 'max',\n",
    "    'Atrial fibrillation': 'mean',\n",
    "    'Atrial flutter': 'mean',\n",
    "    'Accessory pathway conduction': 'mean',\n",
    "    '1st degree atrioventricular block': 'mean',\n",
    "    'Bifascicular block': 'mean',\n",
    "}\n",
    "\n",
    "ECG_FM_SCHEMA = ECGInputSchema(\n",
    "    sample_rate=SAMPLE_RATE,\n",
    "    expected_lead_order=ECG_FM_LEAD_ORDER,\n",
    "    required_num_samples=N_SAMPLES,\n",
    ")\n",
    "\n",
    "ECG_FM_TRANSFORMS = [\n",
    "    ReorderLeads(\n",
    "        expected_order=ECG_FM_LEAD_ORDER,\n",
    "        missing_lead_strategy='raise',\n",
    "    ),\n",
    "    LinearResample(desired_sample_rate=SAMPLE_RATE),\n",
    "    HandleConstantLeads(strategy='zero'),\n",
    "    Standardize(),\n",
    "    SegmentNonoverlapping(segment_length=N_SAMPLES),\n",
    "]\n",
    "\n",
    "loader = file_paths_to_loader(\n",
    "    file_paths,\n",
    "    ECG_FM_SCHEMA,\n",
    "    ECG_FM_TRANSFORMS,\n",
    "    batch_size=batch_size,\n",
    "    num_workers=num_workers,\n",
    ")"
   ]
  },
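  {
   "cell_type": "markdown",
   "id": "b4e2d3c5-6f7a-4081-9b02-c3d4e5f60712",
   "metadata": {},
   "source": [
    "Conceptually, the transform chain above reorders the leads, resamples each lead to 500 Hz, handles constant leads, standardizes the signal, and cuts it into non-overlapping 5 s segments. The NumPy sketch below approximates the last four steps on a random 12-lead recording so the resulting shapes are easy to follow. It is not the `ecg_transform` implementation: `preprocess_sketch` is hypothetical, resampling uses linear interpolation via `np.interp`, standardization is assumed to be a per-lead z-score, and a trailing partial segment is assumed to be dropped.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def preprocess_sketch(feats, org_rate, target_rate=500, seg_seconds=5):\n",
    "    # Hypothetical helper. feats: (n_leads, n_samples) array recorded at org_rate Hz\n",
    "    n_leads, n_samples = feats.shape\n",
    "    duration = n_samples / org_rate\n",
    "\n",
    "    # Linearly resample every lead onto a target_rate time grid\n",
    "    t_org = np.arange(n_samples) / org_rate\n",
    "    n_new = int(round(duration * target_rate))\n",
    "    t_new = np.arange(n_new) / target_rate\n",
    "    resampled = np.stack([np.interp(t_new, t_org, lead) for lead in feats])\n",
    "\n",
    "    # Assumed per-lead z-score; constant leads (std == 0) are set to zero\n",
    "    mean = resampled.mean(axis=1, keepdims=True)\n",
    "    std = resampled.std(axis=1, keepdims=True)\n",
    "    safe_std = np.where(std == 0, 1.0, std)\n",
    "    standardized = np.where(std == 0, 0.0, (resampled - mean) / safe_std)\n",
    "\n",
    "    # Non-overlapping segments of seg_len samples; the remainder is dropped\n",
    "    seg_len = seg_seconds * target_rate\n",
    "    n_segments = n_new // seg_len\n",
    "    trimmed = standardized[:, :n_segments * seg_len]\n",
    "    segments = trimmed.reshape(n_leads, n_segments, seg_len).transpose(1, 0, 2)\n",
    "    return segments  # (n_segments, n_leads, seg_len)\n",
    "\n",
    "\n",
    "# A 10 s, 12-lead recording at 400 Hz becomes two 5 s segments at 500 Hz\n",
    "rng = np.random.default_rng(0)\n",
    "feats = rng.standard_normal((12, 4000))\n",
    "segments = preprocess_sketch(feats, org_rate=400)\n",
    "print(segments.shape)  # (2, 12, 2500)\n",
    "```\n",
    "\n",
    "Under these assumptions, any recording of at least 5 s but under 10 s yields a single segment, which is why the number of segment-level predictions can differ from the number of ECGs."
   ]
  },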
  {
   "cell_type": "markdown",
   "id": "d23b74a1-2306-4c93-8e80-0bbdce958edf",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4742edde-0191-4220-9933-a02a565b4f15",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict, List, Optional, Tuple, Type, Union\n",
    "from collections import OrderedDict\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "\n",
    "from fairseq_signals.models import build_model_from_checkpoint\n",
    "from fairseq_signals.models.classification.ecg_transformer_classifier import (\n",
    "    ECGTransformerClassificationModel\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0871cf80-c4b8-4c91-993b-9d33b1190241",
   "metadata": {},
   "outputs": [],
   "source": [
    "model: ECGTransformerClassificationModel = build_model_from_checkpoint(\n",
    "    checkpoint_path=ckpt_path\n",
    ")\n",
    "\n",
    "# Forcibly enable the return of attention weights for saliency maps\n",
    "if extract_saliency:\n",
    "    model.encoder.encoder.need_weights = extract_saliency\n",
    "    for layer in model.encoder.encoder.layers:\n",
    "        layer.need_weights = extract_saliency\n",
    "\n",
    "model.eval()\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1bbab44-7039-475c-8868-ad2396b5c858",
   "metadata": {},
   "source": [
    "## Infer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7ef175b-838f-41da-bf04-f17622b5063d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encoder_out_to_emb(x):\n",
    "    # Mean-pool over the time dimension, ignoring zeroed (padded) positions,\n",
    "    # mirroring fairseq_signals/models/classification/ecg_transformer_classifier.py\n",
    "    return torch.div(x.sum(dim=1), (x != 0).sum(dim=1))\n",
    "\n",
    "def infer(\n",
    "    model,\n",
    "    loader,\n",
    "    device,\n",
    "    extract_saliency: bool = True,\n",
    "):\n",
    "    inps = []\n",
    "    sources = []\n",
    "    logits = []\n",
    "    embs = []\n",
    "    saliency = []\n",
    "    file_names = []\n",
    "    with torch.no_grad():  # Inference only; gradients are not needed\n",
    "        for source, inp in loader:\n",
    "            source = source.to(device)\n",
    "            out = model(source=source)\n",
    "            inps.extend(inp)\n",
    "            sources.append(source)\n",
    "            logits.append(out['out'])\n",
    "            embs.append(encoder_out_to_emb(out['encoder_out']))\n",
    "            if extract_saliency:\n",
    "                saliency.append(out['saliency'])\n",
    "            file_names.extend([i.meta.file for i in inp])\n",
    "\n",
    "    # Handle predictions\n",
    "    pred = torch.sigmoid(torch.concatenate(logits)).detach().cpu().numpy()\n",
    "    pred = pd.DataFrame(pred, columns=label_names, index=file_names)\n",
    "\n",
    "    results = {\n",
    "        'inps': inps,\n",
    "        'sources': torch.concatenate(sources).detach().cpu().numpy(),\n",
    "        'embs': torch.concatenate(embs).detach().cpu().numpy(),\n",
    "        'pred': pred,\n",
    "    }\n",
    "\n",
    "    # Handle saliency\n",
    "    if extract_saliency:\n",
    "        saliency = torch.concatenate(saliency).detach()\n",
    "        attn = saliency[:, -1] # Consider only the last attention layer\n",
    "        results['attn_max'] = attn.max(axis=2).values.squeeze().cpu().detach().numpy()\n",
    "\n",
    "    return results"
   ]
  },
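  {
   "cell_type": "markdown",
   "id": "c5f3e4d6-7a8b-4192-8c13-d4e5f6071823",
   "metadata": {},
   "source": [
    "`encoder_out_to_emb` turns the encoder output, of shape (batch, time, channels), into one embedding per segment: it sums over the time axis and divides by the number of non-zero entries, so that zeroed (padded) time steps do not dilute the average. The toy NumPy sketch below repeats that arithmetic; `masked_mean_pool` is a hypothetical name and the array values are made up.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def masked_mean_pool(x):\n",
    "    # x: (batch, time, channels). Zero entries are treated as padding and\n",
    "    # excluded from the per-channel average over the time axis\n",
    "    return x.sum(axis=1) / (x != 0).sum(axis=1)\n",
    "\n",
    "\n",
    "# One sequence with 3 time steps and 2 channels; the last step is padding\n",
    "x = np.array([[[1.0, 2.0],\n",
    "               [3.0, 4.0],\n",
    "               [0.0, 0.0]]])\n",
    "pooled = masked_mean_pool(x)\n",
    "print(pooled)          # [[2. 3.]]\n",
    "print(x.mean(axis=1))  # plain mean, diluted by the padded step (4/3 and 2)\n",
    "```\n",
    "\n",
    "One caveat of this formulation: a channel that is exactly zero at every time step would divide zero by zero and produce NaN."
   ]
  },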
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd3fd83c-94fc-45dc-beec-dbc7f5d4cde3",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = infer(model, loader, device, extract_saliency=extract_saliency)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272b7e73-0ce6-48d0-a711-9bf6e6d5da50",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = results['pred']\n",
    "print(f\"Number of 5 s segment predictions: {len(pred)}.\")\n",
    "pred"
   ]
  },
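  {
   "cell_type": "markdown",
   "id": "d6a4f5e7-8b9c-42a3-9d24-e5f607182934",
   "metadata": {},
   "source": [
    "Because this is a multi-label task, `infer` applies a sigmoid to each logit independently rather than a softmax across labels. Each column of `pred` is therefore a separate probability that the corresponding label is present, and a row need not sum to 1. The sketch below contrasts the two on made-up logits; the softmax is shown only for comparison and is not used anywhere in this notebook.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def sigmoid(z):\n",
    "    return 1.0 / (1.0 + np.exp(-z))\n",
    "\n",
    "\n",
    "def softmax(z):\n",
    "    e = np.exp(z - z.max(axis=-1, keepdims=True))\n",
    "    return e / e.sum(axis=-1, keepdims=True)\n",
    "\n",
    "\n",
    "# Made-up logits for one segment and three labels\n",
    "logits = np.array([[2.0, 0.0, -2.0]])\n",
    "probs_multilabel = sigmoid(logits)  # each label scored independently\n",
    "probs_multiclass = softmax(logits)  # labels compete for probability mass\n",
    "\n",
    "print(probs_multilabel)        # approximately [[0.881, 0.5, 0.119]]\n",
    "print(probs_multilabel.sum())  # approximately 1.5, so rows need not sum to 1\n",
    "print(probs_multiclass.sum())  # 1.0 up to floating-point rounding\n",
    "```"
   ]
  },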
  {
   "cell_type": "markdown",
   "id": "74cd45d2-e8e4-4cb5-ba60-582af6fe706a",
   "metadata": {},
   "source": [
    "# Result handling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5e4abebf-02f1-471a-91ce-9c108d37a1fa",
   "metadata": {},
   "source": [
    "## Prediction aggregation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ff4c35-31f3-4c7d-b4f3-1af4f84dc24c",
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_agg = pred.groupby(pred.index).agg(AGG_METHODS).astype(float)\n",
    "print(f\"Number of sample-aggregated predictions: {len(pred_agg)}.\")\n",
    "pred_agg"
   ]
  },
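  {
   "cell_type": "markdown",
   "id": "e7b5a6f8-9cad-43b4-8e35-f60718293a45",
   "metadata": {},
   "source": [
    "`pred.groupby(pred.index).agg(AGG_METHODS)` collapses the segment-level rows of each ECG into a single row, applying a different reduction per label. A plausible reading of `AGG_METHODS` is that `'max'` is used for findings that may show up in only one 5 s segment (for example, an isolated premature ventricular contraction), and `'mean'` for findings expected to persist across the whole recording. The toy example below uses made-up probabilities and two of the label names to show the mechanics.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Made-up segment-level probabilities: ecg_a has two 5 s segments, ecg_b has one\n",
    "pred_toy = pd.DataFrame(\n",
    "    {\n",
    "        'Premature ventricular contraction': [0.1, 0.9, 0.2],\n",
    "        'Sinus rhythm': [0.8, 0.6, 0.3],\n",
    "    },\n",
    "    index=['ecg_a', 'ecg_a', 'ecg_b'],\n",
    ")\n",
    "agg_toy = {\n",
    "    'Premature ventricular contraction': 'max',  # one confident segment is enough\n",
    "    'Sinus rhythm': 'mean',  # averaged across the recording\n",
    "}\n",
    "pred_toy_agg = pred_toy.groupby(pred_toy.index).agg(agg_toy)\n",
    "print(pred_toy_agg)\n",
    "```\n",
    "\n",
    "For `ecg_a`, this gives 0.9 for premature ventricular contraction (the maximum) and 0.7 for sinus rhythm (the mean of 0.8 and 0.6)."
   ]
  },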
  {
   "cell_type": "markdown",
   "id": "d2c68597-a013-4428-8f68-ef47e22ec610",
   "metadata": {},
   "source": [
    "## Visualizing embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f667db06-7946-4ac3-bb34-3e5969f1b104",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import umap\n",
    "\n",
    "reducer = umap.UMAP(n_neighbors=3, min_dist=0.1, n_components=2, random_state=42)\n",
    "embs_2d = reducer.fit_transform(results['embs'])\n",
    "\n",
    "# Generate a color map\n",
    "sample_identifier = pred.index.to_series()\n",
    "unique_values = sample_identifier.unique()\n",
    "colors = plt.colormaps.get_cmap('tab20')  # Use a colormap with enough distinct colors\n",
    "color_map = {val: colors(i) for i, val in enumerate(unique_values)}\n",
    "colored_items = sample_identifier.map(color_map)\n",
    "\n",
    "# Plot the 2D UMAP visualization\n",
    "plt.scatter(\n",
    "    embs_2d[:, 0],\n",
    "    embs_2d[:, 1],\n",
    "    s=30,\n",
    "    alpha=0.9,\n",
    "    color=colored_items.values,\n",
    "    rasterized=True,\n",
    ")\n",
    "\n",
    "# Remove axis labels and grid\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "plt.grid(False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7887ee2d-4b7a-43f2-aac0-51c2b0a5cd30",
   "metadata": {},
   "source": [
    "When visualizing many embeddings, a larger neighborhood and smaller markers are usually more suitable:\n",
    "```\n",
    "import matplotlib.pyplot as plt\n",
    "import umap\n",
    "\n",
    "reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, n_components=2, random_state=42) # Better when there are more embeddings\n",
    "embs_2d = reducer.fit_transform(results['embs'])\n",
    "\n",
    "# Plot the 2D UMAP visualization\n",
    "plt.scatter(\n",
    "    embs_2d[:, 0],\n",
    "    embs_2d[:, 1],\n",
    "    s=1,\n",
    "    alpha=0.9,\n",
    "    rasterized=True,\n",
    ")\n",
    "\n",
    "# Remove axis labels and grid\n",
    "plt.xticks([])\n",
    "plt.yticks([])\n",
    "plt.grid(False)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2e0d7a7-12e2-4bed-a92d-52146ad541e8",
   "metadata": {},
   "source": [
    "## Saliency maps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe12a9ae-7904-4b14-8ecd-b8f5c9ae21f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Tuple\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from scipy.ndimage import map_coordinates\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import plotly.graph_objects as go"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89f20ce6-df0a-49ef-b632-6311baa54fea",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_idx = 0\n",
    "\n",
    "saliency_lead = 'II'\n",
    "lead_ind = ECG_FM_LEAD_ORDER.index(saliency_lead)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c3dc192-9989-4c64-ad8f-026cabb4d735",
   "metadata": {},
   "outputs": [],
   "source": [
    "signal = results['sources'][sample_idx, lead_ind]\n",
    "attn_max = results['attn_max'][sample_idx]"
   ]
  },
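  {
   "cell_type": "markdown",
   "id": "f8c6b7a9-adbe-44c5-9f46-0718293a4b56",
   "metadata": {},
   "source": [
    "`prep_saliency_values` resamples the per-token attention maxima onto a finer grid aligned with the signal and min-max normalizes them to [0, 1], so that each value can serve as a blend factor between the two colors used in the plot. The sketch below illustrates the same two steps with `np.interp` in place of `scipy.ndimage.map_coordinates`. `upsample_and_normalize` is a hypothetical helper, and the eight token scores are arbitrary toy values rather than real model output.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def upsample_and_normalize(token_scores, target_length):\n",
    "    # Linearly interpolate token-level scores onto target_length points\n",
    "    src = np.linspace(0, len(token_scores) - 1, len(token_scores))\n",
    "    dst = np.linspace(0, len(token_scores) - 1, target_length)\n",
    "    upsampled = np.interp(dst, src, token_scores)\n",
    "\n",
    "    # Min-max normalization to [0, 1]; a constant input maps to all zeros\n",
    "    upsampled = upsampled - upsampled.min()\n",
    "    peak = upsampled.max()\n",
    "    return upsampled / peak if peak > 0 else upsampled\n",
    "\n",
    "\n",
    "token_scores = np.array([0.2, 0.5, 0.1, 0.9, 0.3, 0.3, 0.4, 0.2])\n",
    "blend = upsample_and_normalize(token_scores, 2500)\n",
    "print(blend.shape, blend.min(), blend.max())  # (2500,) 0.0 1.0\n",
    "```\n",
    "\n",
    "Guarding against a zero peak avoids the division by zero that a perfectly flat attention profile would otherwise cause."
   ]
  },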
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49aa404f-2187-4649-8a9d-6b6e50168048",
   "metadata": {},
   "outputs": [],
   "source": [
    "def blend_colors_hex(start_color: str, end_color: str, activations: np.ndarray) -> np.ndarray:\n",
    "    \"\"\"\n",
    "    Blends between two colors based on an array of blend factors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    start_color : str\n",
    "        Hexadecimal color code for the start color.\n",
    "    end_color : str\n",
    "        Hexadecimal color code for the end color.\n",
    "    activations : np.ndarray\n",
    "        An array of blend factors where 0 corresponds to the start color and 1 to the end color.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray\n",
    "        An array of shape (len(activations), 3) holding the blended RGB values scaled to [0, 1].\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If any of the input blend factors are not within the range [0, 1].\n",
    "    \"\"\"\n",
    "    if np.any((activations < 0) | (activations > 1)):\n",
    "        raise ValueError(\"All blend factors must be between 0 and 1.\")\n",
    "\n",
    "    # Convert hexadecimal to RGB\n",
    "    def hex_to_rgb(hex_color: str) -> Tuple[int, int, int]:\n",
    "        return tuple(int(hex_color[i: i+2], 16) for i in (1, 3, 5))\n",
    "\n",
    "    # Get RGB tuples\n",
    "    start_rgb = np.array(hex_to_rgb(start_color))\n",
    "    end_rgb = np.array(hex_to_rgb(end_color))\n",
    "\n",
    "    # Blend RGB values\n",
    "    blended_rgb = np.outer(1 - activations, start_rgb) + np.outer(activations, end_rgb)\n",
    "\n",
    "    # Scale the blended RGB values to [0, 1]\n",
    "    return blended_rgb / 255\n",
    "\n",
    "def colored_line_segments(data: np.ndarray, colors: np.ndarray, ax=None, **kwargs):\n",
    "    \"\"\"\n",
    "    Plots line segments based on the provided data points, with each segment\n",
    "    colored according to the corresponding color specification in `colors`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    data : np.ndarray\n",
    "        Array of y-values for the line segments.\n",
    "    colors : np.ndarray\n",
    "        Array of colors, each color applied to the corresponding line segment\n",
    "        between points i and i+1.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If the `colors` array does not have exactly one fewer element than the `data` array,\n",
    "        as each segment needs a unique color.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "    \"\"\"\n",
    "    if len(colors) != len(data) - 1:\n",
    "        raise ValueError(\"Colors array must have one fewer element than data array.\")\n",
    "\n",
    "    if ax is None:\n",
    "        for i in range(len(data) - 1):\n",
    "            plt.plot([i, i + 1], [data[i], data[i + 1]], color=colors[i], **kwargs)\n",
    "    else:\n",
    "        for i in range(len(data) - 1):\n",
    "            ax.plot([i, i + 1], [data[i], data[i + 1]], color=colors[i], **kwargs)\n",
    "\n",
    "def prep_saliency_values(attn_max, target_sample_length):\n",
    "    # Resample to original sample size\n",
    "    new_dims = [\n",
    "        np.linspace(0, original_length-1, new_length) \\\n",
    "        for original_length, new_length in \\\n",
    "        zip(attn_max.shape, (target_sample_length - 1,))\n",
    "    ]\n",
    "    coords = np.meshgrid(*new_dims, indexing='ij')\n",
    "    attn_max = map_coordinates(attn_max, coords)\n",
    "\n",
    "    # Min-max normalization\n",
    "    attn_max = attn_max - attn_max.min()\n",
    "    attn_max = attn_max/attn_max.max()\n",
    "\n",
    "    return attn_max\n",
    "\n",
    "saliency_prepped = prep_saliency_values(\n",
    "    attn_max.ravel(),\n",
    "    attn_max.shape[0] * signal.shape[-1],\n",
    ")\n",
    "saliency_colors = blend_colors_hex('#0047AB', '#DC143C', saliency_prepped)\n",
    "saliency_colors = (saliency_colors*255).astype(int)\n",
    "\n",
    "# Define a custom colorscale from blue to red\n",
    "colorscale = [[0, 'blue'], [1, 'red']]  # Simple gradient from blue to red\n",
    "\n",
    "time = np.arange(len(signal))\n",
    "\n",
    "# Create the figure\n",
    "fig = go.Figure()\n",
    "y_values = signal[:-1]\n",
    "for i in range(len(y_values) - 1):\n",
    "    fig.add_trace(\n",
    "        go.Scatter(\n",
    "            x=[time[i], time[i + 1]],\n",
    "            y=[y_values[i], y_values[i + 1]],\n",
    "            mode='lines',\n",
    "            line=dict(color='rgb({},{},{})'.format(*saliency_colors[i]), width=2),\n",
    "            showlegend=False  # Avoid cluttering the legend\n",
    "        )\n",
    "    )\n",
    "fig['layout']['yaxis'].update(autorange = True)\n",
    "fig['layout']['xaxis'].update(autorange = True)\n",
    "\n",
    "fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "fairseq",
   "language": "python",
   "name": "fairseq"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
